iT邦幫忙

2024 iThome 鐵人賽

DAY 10
0

今天是第十天可以寫一個Lstm模型訓練來看我的斑馬魚座標的predict,以下是程式碼

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

# 生成模擬的斑馬魚座標數據
def generate_data(timesteps, num_samples):
    X = []
    y = []
    for i in range(num_samples):
        start_x = np.random.rand()
        start_y = np.random.rand()
        x_coords = np.linspace(start_x, start_x + np.random.rand(), timesteps)
        y_coords = np.linspace(start_y, start_y + np.random.rand(), timesteps)
        sequence = np.column_stack((x_coords, y_coords))
        X.append(sequence[:-1])
        y.append(sequence[-1])
    return np.array(X), np.array(y)

# 設定參數
timesteps = 10  # 使用的時間步長
num_samples = 1000  # 數據樣本數量

# 生成訓練數據
X_train, y_train = generate_data(timesteps, num_samples)

# 建立LSTM模型
model = Sequential()
model.add(LSTM(50, activation='relu', input_shape=(timesteps-1, 2)))
model.add(Dense(2))  # 輸出層,輸出兩個座標 (x, y)
model.compile(optimizer='adam', loss='mse')

# 訓練模型
model.fit(X_train, y_train, epochs=200, verbose=1)

# 測試模型
X_test, y_test = generate_data(timesteps, 10)
predictions = model.predict(X_test)

# 顯示結果
for i in range(10):
    print(f"真實座標: {y_test[i]}, 預測座標: {predictions[i]}")

1. 依賴庫的匯入

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
  • numpy (np): 這是Python的一個數值計算庫,用於數組操作。
  • tensorflow (tf): 一個深度學習框架,用於構建和訓練神經網絡。
  • Sequential: TensorFlow中用於構建神經網絡的線性堆疊模型。
  • LSTM: 一種特殊的RNN(循環神經網絡),適合處理和預測基於時間序列的數據。
  • Dense: 全連接層(也稱為密集層),是神經網絡中基本的神經元層。

2. 數據生成函數

def generate_data(timesteps, num_samples):
    X = []
    y = []
    for i in range(num_samples):
        start_x = np.random.rand()
        start_y = np.random.rand()
        x_coords = np.linspace(start_x, start_x + np.random.rand(), timesteps)
        y_coords = np.linspace(start_y, start_y + np.random.rand(), timesteps)
        sequence = np.column_stack((x_coords, y_coords))
        X.append(sequence[:-1])
        y.append(sequence[-1])
    return np.array(X), np.array(y)
  • generate_data: 這個函數用於生成模擬的斑馬魚座標數據。
    • timesteps: 時間步長,決定了每個序列中數據點的數量。
    • num_samples: 樣本數量,決定了生成多少序列。
    • start_x, start_y: 起始點的隨機 x 和 y 座標。
    • x_coords, y_coords: 座標序列,從起點開始,按照隨機的增量生成。
    • sequence: 每個序列包含時間步內的所有座標對 (x, y)
    • X: 訓練輸入數據,包含每個序列的所有座標點(除了最後一個)。
    • y: 訓練目標數據,對應每個序列的最後一個座標點(預測目標)。

3. 設定參數

timesteps = 10  # 使用的時間步長
num_samples = 1000  # 數據樣本數量
  • timesteps: 每個樣本包含的時間步長,這裡設定為10。
  • num_samples: 生成1000個樣本數據。

4. 生成訓練數據

X_train, y_train = generate_data(timesteps, num_samples)
  • 生成訓練數據,X_train 包含訓練輸入,y_train 包含訓練目標。

5. 建立LSTM模型

model = Sequential()
model.add(LSTM(50, activation='relu', input_shape=(timesteps-1, 2)))
model.add(Dense(2))  # 輸出層,輸出兩個座標 (x, y)
model.compile(optimizer='adam', loss='mse')
  • Sequential: 建立一個線性的模型堆疊。
  • LSTM(50, activation='relu', input_shape=(timesteps-1, 2)): 添加一個LSTM層。
    • 50個單元(神經元)來學習時間序列模式。
    • relu 激活函數用於增加非線性。
    • input_shape=(timesteps-1, 2):輸入數據的形狀,時間步長為 timesteps-1,每個步長有兩個輸入值(xy 座標)。
  • Dense(2): 添加一個全連接層,輸出兩個值對應於 (x, y) 預測座標。
  • model.compile: 編譯模型,設置優化器為 adam,損失函數為 mse(均方誤差)。

6. 訓練模型

model.fit(X_train, y_train, epochs=200, verbose=1)
  • model.fit: 訓練模型,epochs=200 表示訓練200個週期。

7. 測試模型

X_test, y_test = generate_data(timesteps, 10)
predictions = model.predict(X_test)
  • 生成測試數據,並使用訓練好的模型進行預測。

8. 顯示結果

for i in range(10):
    print(f"真實座標: {y_test[i]}, 預測座標: {predictions[i]}")
  • 輸出測試數據的真實座標與模型預測座標,以比較預測效果。

總結

  • 這段程式碼模擬了一些座標數據,並使用LSTM模型來預測座標的未來位置。
  • 可以用這個框架來應用於實際的斑馬魚座標數據,透過調整參數和數據處理方式來提升模型的預測能力。

上一篇
Day 9Lstm預測兩隻斑馬魚行為分析
下一篇
day 11 yolo 辨識養殖豬隻系統
系列文
基於人工智慧與深度學習對斑馬魚做行為分析30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言